import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from scipy.stats import norm, lognorm
import pickle
import h5py
import random
import shutil
import os, glob, re


# Input directory holding the CSV files interpolated in the previous step;
# adjust the name to match the dataset (5D, 7D, 9D).
input_directory = './processed_data_5D/'

# Load all CSV files from the directory
all_files = [f for f in os.listdir(input_directory) if f.endswith('.csv')]

# Print the number of files
print('Number of files:', len(all_files))

# Fail early with a clear message instead of an opaque IndexError on
# all_files[0] below.
if not all_files:
    raise FileNotFoundError(f"No .csv files found in {input_directory!r}")

# Base column order: the 5D set is the first five entries, 7D adds mu/mux,
# 9D adds PL/PR.
candidate_cols = ['t_p', 'Voltage', 'h', 'C', 'δ_value', 'mu', 'mux', 'PL', 'PR']

# Detect available columns from the first CSV file
first_file_path = os.path.join(input_directory, all_files[0])
sample_df = pd.read_csv(first_file_path)

# Keep only the candidate columns present in that file (preserves order)
columns_to_scale = [c for c in candidate_cols if c in sample_df.columns]


# Step 1: Extract maximum current values I_max from all I(t) sequences
def extract_max_currents(data_folder):
    """Return the peak ``Current`` value of every CSV file in *data_folder*.

    Files are visited in ``os.listdir`` order; entries not ending in ``.csv``
    are skipped.

    Args:
        data_folder: Directory containing the CSV files.

    Returns:
        1-D numpy array with one maximum per CSV file, in visit order.
    """
    csv_names = (name for name in os.listdir(data_folder) if name.endswith(".csv"))
    # The current trace is expected in a column named 'Current'.
    peaks = [
        pd.read_csv(os.path.join(data_folder, name))['Current'].max()
        for name in csv_names
    ]
    return np.array(peaks)


# Step 2: Fit a log-normal distribution to the I_max histogram
def fit_lognormal_distribution(I_max_list):
    """Fit a log-normal (location fixed at 0) and return its mean and std.

    ``scipy.stats.lognorm`` is parameterised by a shape ``s`` and a ``scale``;
    the mean is ``scale * exp(s**2 / 2)`` and the standard deviation is
    ``scale * sqrt((exp(s**2) - 1) * exp(s**2))``.

    NOTE(review): these are linear-space moments, while
    ``standardize_currents`` subtracts them from ``log10(I_max)`` — confirm
    this mix of spaces is intended.

    Args:
        I_max_list: Positive peak currents to fit.

    Returns:
        ``(mu_N, sigma_N)``: mean and standard deviation of the fitted law.
    """
    s, _loc, scale = lognorm.fit(I_max_list, floc=0)
    s_squared = s**2
    mean = scale * np.exp(s_squared / 2)
    std = scale * np.sqrt((np.exp(s_squared) - 1) * np.exp(s_squared))
    return mean, std

# Step 3: Standardize currents I_(max,N)
def standardize_currents(I_max_list, mu_N, sigma_N):
    """Standardize peak currents in log10 space.

    Args:
        I_max_list: Array of peak currents (must be positive for log10).
        mu_N: Centering value subtracted after taking log10.
        sigma_N: Divisor applied after centering.

    Returns:
        ``(log10(I_max_list) - mu_N) / sigma_N`` as a numpy array.
    """
    centered = np.log10(I_max_list) - mu_N
    return centered / sigma_N

# Step 4: Scale current vectors I_s
def scale_current_vectors(data_folder, I_max_N, I_max_list):
    """Rescale each current trace so its peak maps to its standardized peak.

    For the k-th CSV file (``os.listdir`` order, non-CSV entries skipped —
    the same order ``extract_max_currents`` uses), the ``Current`` column is
    multiplied by ``I_max_N[k] / I_max_list[k]``.

    Args:
        data_folder: Directory containing the CSV files.
        I_max_N: Standardized peak currents, one per CSV file.
        I_max_list: Raw peak currents, one per CSV file.

    Returns:
        List of pandas Series (the scaled ``Current`` columns), one per CSV file.
    """
    # Enumerate only CSV files: indexing with the raw os.listdir position would
    # pair a file with another file's peak (or overrun the arrays) whenever the
    # directory also holds non-CSV entries.
    csv_files = [f for f in os.listdir(data_folder) if f.endswith(".csv")]
    I_s_list = []
    for idx, file in enumerate(csv_files):
        data = pd.read_csv(os.path.join(data_folder, file))
        I_t = data['Current']
        I_s_list.append(I_t * (I_max_N[idx] / I_max_list[idx]))
    return I_s_list


# Step 5: Fit a normal distribution to I_(s,max)
def fit_normal_distribution(I_s_max_list):
    """Fit a normal distribution to the scaled peak currents.

    Args:
        I_s_max_list: Sequence of scaled peak currents.

    Returns:
        ``(mu_s, sigma_s)``: the location and scale from ``scipy.stats.norm.fit``.
    """
    loc, scale = norm.fit(I_s_max_list)
    return loc, scale


# Step 6: Normalize currents I_n
def normalize_currents(I_s_list, mu_s):
    """Divide every scaled current trace by *mu_s* and flip its sign.

    Args:
        I_s_list: List of scaled current traces (pandas Series or arrays).
        mu_s: Normalization divisor.

    Returns:
        New list with ``-(I_s / mu_s)`` for each trace, same container types.
    """
    return [-(trace / mu_s) for trace in I_s_list]

# Step 7: Scale all parameters and normalize currents
def scale_and_normalize_parameters(input_directory, columns_to_scale):
    """Min-max scale parameter columns globally and attach normalized currents.

    Every ``.csv`` file in *input_directory* is read. Each column in
    *columns_to_scale* gets one ``MinMaxScaler`` fitted on the pooled values of
    all files. The ``Current`` column is replaced by the normalized current
    produced by ``extract_max_currents`` -> ... -> ``normalize_currents``.

    Args:
        input_directory: Directory containing the interpolated CSV files.
        columns_to_scale: Names of the parameter columns to min-max scale.

    Returns:
        ``(scaled_data, (mu_N, sigma_N, mu_s))``. ``scaled_data`` maps each CSV
        filename to a 2-D array whose columns are ``Time``, then the scaled
        parameters (in *columns_to_scale* order, omitting any missing from that
        file), then the normalized current as the last column.
    """
    # Filter once so a file's position here matches its position in the
    # CSV-only per-file lists built by the current pipeline below. This assumes
    # os.listdir returns the same order on every call (directory unchanged
    # during the run).
    csv_files = [f for f in os.listdir(input_directory) if f.endswith('.csv')]

    # Collect all values for global MinMax scaling
    all_values = {col: [] for col in columns_to_scale}
    for filename in csv_files:
        data = pd.read_csv(os.path.join(input_directory, filename))
        for col in columns_to_scale:
            if col in data.columns:
                all_values[col].extend(data[col].values)

    # Fit one MinMaxScaler per column over the pooled values of every file
    scalers = {
        col: MinMaxScaler().fit(np.array(all_values[col]).reshape(-1, 1))
        for col in columns_to_scale
    }

    # Extract and process the Current column
    I_max_list = extract_max_currents(input_directory)
    mu_N, sigma_N = fit_lognormal_distribution(I_max_list)
    I_max_N = standardize_currents(I_max_list, mu_N, sigma_N)
    I_s_list = scale_current_vectors(input_directory, I_max_N, I_max_list)
    I_s_max_list = [I_s.max() for I_s in I_s_list]
    mu_s, _sigma_s = fit_normal_distribution(I_s_max_list)
    I_n_list = normalize_currents(I_s_list, mu_s)

    # Scale the parameters globally and replace the Current column with
    # normalized values. idx counts CSV files only, matching I_n_list.
    scaled_data = {}
    for idx, filename in enumerate(csv_files):
        data = pd.read_csv(os.path.join(input_directory, filename))

        # Keep the Time column as is
        time_column = data['Time'].values.reshape(-1, 1)

        scaled_columns = []
        for col in columns_to_scale:
            if col in data.columns:
                # The scalers were fitted on plain ndarrays, so transform an
                # ndarray too (avoids sklearn's feature-name mismatch warning).
                scaled_columns.append(scalers[col].transform(data[[col]].to_numpy()))
            else:
                print(f"'{col}' column not found in {filename}")

        # Combine Time, scaled columns, and normalized Current
        current_column = I_n_list[idx].values.reshape(-1, 1)
        scaled_data[filename] = np.hstack([time_column, *scaled_columns, current_column])

        print(f"Processed: {filename}")

    print("All files have been processed and the specified columns have been scaled globally.")
    return scaled_data, (mu_N, sigma_N, mu_s)

# Run the full scaling/normalization pipeline over the input directory
scaled_parameters_data, (mu_N, sigma_N, mu_s) = scale_and_normalize_parameters(
    input_directory, columns_to_scale
)

# Sanity check: take the first processed file and compare its row count with
# the original CSV on disk.
example_file = list(scaled_parameters_data)[0]
scaled_data = scaled_parameters_data[example_file]

original_file_path = os.path.join(input_directory, example_file)
original_data = pd.read_csv(original_file_path)

original_data_points, scaled_data_points = original_data.shape[0], scaled_data.shape[0]

print(f"Number of data points in the original file ({example_file}): {original_data_points}")
print(f"Number of data points in the scaled file ({example_file}): {scaled_data_points}")

print(
    "The number of data points is consistent before and after scaling and normalizing."
    if original_data_points == scaled_data_points
    else "WARNING: The number of data points has changed after scaling and normalizing!"
)


def save_hdf5(file_path, data):
    """Write each sample dict into its own ``sample_<i>`` group of an HDF5 file.

    The file at *file_path* is created or overwritten (mode ``'w'``). For each
    sample, ``filename`` is stored as a group attribute and ``Time``,
    ``scaled_parameters`` and ``normalized_current`` are stored as datasets.

    Args:
        file_path: Destination ``.h5`` path.
        data: Iterable of dicts with keys ``filename``, ``Time``,
            ``scaled_parameters`` and ``normalized_current``.
    """
    with h5py.File(file_path, 'w') as h5:
        for i, sample in enumerate(data):
            group = h5.create_group(f'sample_{i}')
            group.attrs['filename'] = sample['filename']
            for key in ('Time', 'scaled_parameters', 'normalized_current'):
                group.create_dataset(key, data=sample[key])
    print(f"Data saved to {file_path}")


# Prepare data for splitting: unpack each scaled array into its
# Time | parameters | normalized-current parts.
data = [
    {
        'filename': filename,
        'Time': scaled_array[:, 0],                  # first column
        'scaled_parameters': scaled_array[:, 1:-1],  # between Time and Current
        'normalized_current': scaled_array[:, -1],   # last column
    }
    for filename, scaled_array in scaled_parameters_data.items()
]

# Hold out 10% of the files as a test set (fixed seed for reproducibility)
train_data, test_data = train_test_split(data, test_size=0.1, random_state=42)

# Output HDF5 file names; adjust the suffix to match the dataset (5D, 7D, 9D)
train_h5_path = 'train_dataset_2_t_p_20000_5D.h5'
test_h5_path = 'test_dataset_2_t_p_20000_5D.h5'

for h5_path, split in ((train_h5_path, train_data), (test_h5_path, test_data)):
    save_hdf5(h5_path, split)






